{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "ImageSegmentation_ModelSubclassing.ipynb",
      "provenance": [],
      "private_outputs": true,
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cZCM65CBt1CJ"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "JOgMcEajtkmg"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rCSP-dbMw88x"
      },
      "source": [
        "# Image segmentation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "NEWs8JXRuGex"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/examples/blob/master/community/en/ImageSegmentation_ModelSubclassing.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/examples/tree/master/community/en/ImageSegmentation_ModelSubclassing.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "sMP7mglMuGT2"
      },
      "source": [
        "This tutorial focuses on the task of image segmentation, using an encoder-decoder architecture, implemented with model subclassing API.\n",
        "\n",
        "## What is image segmentation?\n",
        "So far you have seen image classification, where the task of the network is to assign a label or class to an input image. However, suppose you want to know where an object is located in the image, the shape of that object, which pixel belongs to which object, etc. In this case you will want to segment the image, i.e., each pixel of the image is given a label. Thus, the task of image segmentation is to train a neural network to output a pixel-wise mask of the image. This helps in understanding the image at a much finer granularity, i.e., the pixel level. Image segmentation has many applications in medical imaging, self-driving cars and satellite imaging, to name a few.\n",
        "\n",
        "The dataset that will be used for this tutorial is the [Oxford-IIIT Pet Dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/), created by Parkhi *et al*. The dataset consists of images, their corresponding labels, and pixel-wise masks. The masks are basically labels for each pixel. Each pixel is given one of three categories :\n",
        "\n",
        "*   Class 1 : Pixel belonging to the pet.\n",
        "*   Class 2 : Pixel bordering the pet.\n",
        "*   Class 3 : None of the above/ Surrounding pixel."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "MQmKthrSBCld"
      },
      "outputs": [],
      "source": [
        "!pip install git+https://github.com/tensorflow/examples.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "g87--n2AtyO_"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "assert tf.__version__.startswith('2')\n",
        "import tensorflow_datasets as tfds\n",
        "tfds.disable_progress_bar()\n",
        "\n",
        "tf.executing_eagerly()\n",
        "\n",
        "from IPython.display import clear_output\n",
        "import matplotlib.pyplot as plt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "oWe0_rQM4JbC"
      },
      "source": [
        "## Download the Oxford-IIIT Pets dataset\n",
        "\n",
        "The dataset is already included in TensorFlow datasets, all that is needed to do is download it. The segmentation masks are included in version 3.0.0, which is why this particular version is used."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "40ITeStwDwZb"
      },
      "outputs": [],
      "source": [
        "dataset, info = tfds.load('oxford_iiit_pet:3.0.0', with_info=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rJcVdj_U4vzf"
      },
      "source": [
        "The following code performs a simple augmentation of flipping an image. In addition,  image is normalized to [0,1]. Finally, as mentioned above the pixels in the segmentation mask are labeled either {1, 2, 3}. For the sake of convenience, let's subtract 1 from the segmentation mask, resulting in labels that are : {0, 1, 2}."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "FD60EbcAQqov"
      },
      "outputs": [],
      "source": [
        "def normalize(input_image, input_mask):\n",
        "  input_image = tf.cast(input_image, tf.float32)/128.0 - 1\n",
        "  input_mask -= 1\n",
        "  return input_image, input_mask"
      ]
    },
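    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a quick check of the helper above, the following sketch applies `normalize` to a hypothetical 1x2 image holding the extreme pixel values 0 and 255, together with a mask holding the labels 1 and 3. These tensors are made up for illustration and are not taken from the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Illustrative tensors (assumed for this check, not taken from the dataset).\n",
        "toy_image = tf.constant([[[0, 0, 0], [255, 255, 255]]], dtype=tf.uint8)\n",
        "toy_mask = tf.constant([[[1], [3]]], dtype=tf.uint8)\n",
        "\n",
        "toy_image_norm, toy_mask_shifted = normalize(toy_image, toy_mask)\n",
        "\n",
        "# Pixel value 0 maps to -1.0 and 255 maps to 255/128 - 1 = 0.9921875.\n",
        "print(toy_image_norm.numpy().min(), toy_image_norm.numpy().max())\n",
        "# Mask labels 1 and 3 become 0 and 2.\n",
        "print(toy_mask_shifted.numpy().ravel())"
      ]
    },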
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "2NPlCnBXQwb1"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def load_image_train(datapoint):\n",
        "  input_image = tf.image.resize(datapoint['image'], (128, 128))\n",
        "  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))\n",
        "\n",
        "  if tf.random.uniform(()) > 0.5:\n",
        "    input_image = tf.image.flip_left_right(input_image)\n",
        "    input_mask = tf.image.flip_left_right(input_mask)\n",
        "\n",
        "  input_image, input_mask = normalize(input_image, input_mask)\n",
        "\n",
        "  return input_image, input_mask"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Zf0S67hJRp3D"
      },
      "outputs": [],
      "source": [
        "def load_image_test(datapoint):\n",
        "  input_image = tf.image.resize(datapoint['image'], (128, 128))\n",
        "  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))\n",
        "\n",
        "  input_image, input_mask = normalize(input_image, input_mask)\n",
        "\n",
        "  return input_image, input_mask"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "65-qHTjX5VZh"
      },
      "source": [
        "The dataset already contains the required splits of test and train and so let's continue to use the same split."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "yHwj2-8SaQli"
      },
      "outputs": [],
      "source": [
        "TRAIN_LENGTH = info.splits['train'].num_examples\n",
        "BATCH_SIZE = 64\n",
        "BUFFER_SIZE = 1000\n",
        "STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "39fYScNz9lmo"
      },
      "outputs": [],
      "source": [
        "train = dataset['train'].map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "test = dataset['test'].map(load_image_test)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "DeFwFDN6EVoI"
      },
      "outputs": [],
      "source": [
        "train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()\n",
        "train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)\n",
        "test_dataset = test.batch(BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Xa3gMAE_9qNa"
      },
      "source": [
        "Let's take a look at an image example and it's corresponding mask from the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "3N2RPAAW9q4W"
      },
      "outputs": [],
      "source": [
        "def display(display_list):\n",
        "  plt.figure(figsize=(15, 15))\n",
        "\n",
        "  title = ['Input Image', 'True Mask', 'Predicted Mask']\n",
        "\n",
        "  for i in range(len(display_list)):\n",
        "    plt.subplot(1, len(display_list), i+1)\n",
        "    plt.title(title[i])\n",
        "    plt.imshow(tf.keras.preprocessing.image.array_to_img(display_list[i]))\n",
        "    plt.axis('off')\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "a6u_Rblkteqb"
      },
      "outputs": [],
      "source": [
        "for image, mask in train.take(1):\n",
        "  sample_image, sample_mask = image, mask\n",
        "display([sample_image, sample_mask])\n",
        "\n",
        "print(sample_image.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FAOe93FRMk3w"
      },
      "source": [
        "## Define the model\n",
        "The model used here consists of an encoder (downsampler) and decoder (upsampler). \n",
        "\n",
        "Here we define encoder and decoder using model subclassing API. Each encoder block is basically a sequence of Conv layers and MaxPooling layers (UpSampling layers for decoder block), plus residual connection.  The output activation of each block is used as input for the next block, where the activation is not only further processed in the sequence of Conv/MaxPooling layers as in a Sequential model, but also processed in residual layers and added to its output.\n",
        "\n",
        "The reason to output three channels is because there are three possible labels for each pixel. Think of this as multi-classification where each pixel is  classified into three classes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "c6iB4iMvMkX9"
      },
      "outputs": [],
      "source": [
        "OUTPUT_CHANNELS = 3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "zO_W5A3ke9hZ"
      },
      "outputs": [],
      "source": [
        "from tensorflow.keras import layers\n",
        "\n",
        "class EncoderBlock(tf.keras.Model):\n",
        "  def __init__(self, filter_size):\n",
        "    # initilize instance variables\n",
        "    super(EncoderBlock, self).__init__()\n",
        "    self.filter_size = filter_size\n",
        "    # define layers\n",
        "    self.layer_1 = layers.Activation('relu')\n",
        "    self.layer_2 = layers.SeparableConv2D(self.filter_size, 3, padding='same')\n",
        "    self.layer_3 = layers.BatchNormalization()\n",
        "\n",
        "    self.layer_4 = layers.Activation('relu')\n",
        "    self.layer_5 = layers.SeparableConv2D(self.filter_size, 3, padding='same')\n",
        "    self.layer_6 = layers.BatchNormalization()\n",
        "\n",
        "    self.layer_7 = layers.MaxPooling2D(3, strides=2, padding='same')\n",
        "\n",
        "    # project residual\n",
        "    self.residual_layer = layers.Conv2D(self.filter_size, 1, strides=2, padding='same')\n",
        "  def call(self, inputs):\n",
        "    x = self.layer_1(inputs)\n",
        "    x = self.layer_2(x)\n",
        "    x = self.layer_3(x)\n",
        "    x = self.layer_4(x)\n",
        "    x = self.layer_5(x)\n",
        "    x = self.layer_6(x)\n",
        "    x = self.layer_7(x)\n",
        "    residual = self.residual_layer(inputs)\n",
        "    x = layers.add([x, residual])\n",
        "    return x\n",
        "\n",
        "class DecoderBlock(tf.keras.Model):\n",
        "  def __init__(self, filter_size):\n",
        "    # initilize instance variables\n",
        "    super(DecoderBlock, self).__init__()\n",
        "    self.filter_size = filter_size\n",
        "    # define layers\n",
        "    self.layer_1 = layers.Activation('relu')\n",
        "    self.layer_2 = layers.Conv2DTranspose(self.filter_size, 3, padding='same')\n",
        "    self.layer_3 = layers.BatchNormalization()\n",
        "\n",
        "    self.layer_4 = layers.Activation('relu')\n",
        "    self.layer_5 = layers.Conv2DTranspose(self.filter_size, 3, padding='same')\n",
        "    self.layer_6 = layers.BatchNormalization()\n",
        "\n",
        "    self.layer_7 = layers.UpSampling2D(2)\n",
        "\n",
        "    # project residual\n",
        "    self.residual_layer_1 = layers.UpSampling2D(2)\n",
        "    self.residual_layer_2 = layers.Conv2D(filter_size, 1, padding='same')\n",
        "  def call(self, inputs):\n",
        "    x = self.layer_1(inputs)\n",
        "    x = self.layer_2(x)\n",
        "    x = self.layer_3(x)\n",
        "    x = self.layer_4(x)\n",
        "    x = self.layer_5(x)\n",
        "    x = self.layer_6(x)\n",
        "    x = self.layer_7(x)\n",
        "    residual = self.residual_layer_1(inputs)\n",
        "    residual = self.residual_layer_2(residual)\n",
        "\n",
        "    x = layers.add([x, residual])\n",
        "    return x"
      ]
    },
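    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the effect of the blocks concrete, here is a minimal sanity check. It assumes the `EncoderBlock` and `DecoderBlock` classes from the previous cell and feeds them a random tensor whose shape `(1, 32, 32, 32)` is chosen purely for illustration. The encoder block should halve the spatial size, and the decoder block should double it again."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Random feature map standing in for an intermediate activation\n",
        "# (the shape is an assumption chosen for illustration).\n",
        "dummy_features = tf.random.normal((1, 32, 32, 32))\n",
        "\n",
        "encoded = EncoderBlock(64)(dummy_features)\n",
        "decoded = DecoderBlock(32)(encoded)\n",
        "\n",
        "print(encoded.shape)  # (1, 16, 16, 64): spatial size halved, 64 filters\n",
        "print(decoded.shape)  # (1, 32, 32, 32): spatial size doubled, 32 filters"
      ]
    },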
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "EGP8hTwhc3Ww"
      },
      "outputs": [],
      "source": [
        "class ImageSegmentationModel(tf.keras.Model):\n",
        "  def __init__(self, output_channels, dynamic=True):\n",
        "    # initilize instance variables\n",
        "    super(ImageSegmentationModel, self).__init__()\n",
        "    self.output_channels = output_channels\n",
        "\n",
        "    self.entry_block_1 = layers.Conv2D(32, 3, strides=2, padding='same')\n",
        "    self.entry_block_2 = layers.BatchNormalization()\n",
        "    self.entry_block_3 = layers.Activation('relu')\n",
        "\n",
        "    self.encoder_block_1 = EncoderBlock(64)\n",
        "    self.encoder_block_2 = EncoderBlock(128)\n",
        "    self.encoder_block_3 = EncoderBlock(256)\n",
        "\n",
        "    self.decoder_block_1 = DecoderBlock(256)\n",
        "    self.decoder_block_2 = DecoderBlock(128)\n",
        "    self.decoder_block_3 = DecoderBlock(64)\n",
        "    self.decoder_block_4 = DecoderBlock(32)\n",
        "\n",
        "    self.output_layer = layers.Conv2D(\n",
        "        output_channels, 3, activation='sigmoid', padding='same')\n",
        "    \n",
        "  def call(self, inputs):\n",
        "    x = self.entry_block_1(inputs)\n",
        "    x = self.entry_block_2(x)\n",
        "    x = self.entry_block_3(x)\n",
        "    x = self.encoder_block_1(x)\n",
        "    x = self.encoder_block_2(x)\n",
        "    x = self.encoder_block_3(x)\n",
        "    x = self.decoder_block_1(x)\n",
        "    x = self.decoder_block_2(x)\n",
        "    x = self.decoder_block_3(x)\n",
        "    x = self.decoder_block_4(x)\n",
        "    x = self.output_layer(x)\n",
        "    return x"
      ]
    },
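    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a rough check that the encoder and decoder stages line up, the sketch below passes a dummy all-zeros batch through a freshly constructed model. It assumes the `ImageSegmentationModel` class and `OUTPUT_CHANNELS` defined above; the batch itself is hypothetical. The spatial size goes from 128 to 64 in the entry block, down to 8 after the three encoder blocks, and back up to 128 after the four decoder blocks, so the output should hold one score per class for every input pixel."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Hypothetical input batch: one 128x128 RGB image filled with zeros.\n",
        "dummy_batch = tf.zeros((1, 128, 128, 3))\n",
        "\n",
        "# A throwaway model instance, used only to inspect the output shape.\n",
        "shape_check_model = ImageSegmentationModel(OUTPUT_CHANNELS)\n",
        "print(shape_check_model(dummy_batch).shape)  # (1, 128, 128, 3)"
      ]
    },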
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "j0DGH_4T0VYn"
      },
      "source": [
        "## Train the model\n",
        "Now, all that is left to do is to compile and train the model. The loss used here is losses.sparse_categorical_crossentropy. The reason to use this loss function is that the network is trying to assign each pixel a label, just like multi-class prediction. In the true segmentation mask, each pixel has either a {0,1,2}. The network here is outputting three channels. Essentially, each channel is trying to learn to predict a class, and losses.sparse_categorical_crossentropy is the recommended loss for such a scenario. Using the output of the network, the label assigned to the pixel is the channel with the highest value. This is what the create_mask function is doing."
      ]
    },
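    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The per-pixel loss and the argmax step described above can be sketched on a toy example. The probabilities and labels below are made up for illustration: a batch with a single 1x2 image, three class probabilities per pixel, and integer labels without the trailing channel dimension that the dataset masks carry."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code"
      },
      "outputs": [],
      "source": [
        "# Toy prediction with shape (batch=1, height=1, width=2, classes=3).\n",
        "# All values are assumptions chosen for illustration.\n",
        "toy_probs = tf.constant([[[[0.7, 0.2, 0.1],\n",
        "                           [0.1, 0.1, 0.8]]]])\n",
        "# True class index for each pixel, shape (1, 1, 2).\n",
        "toy_labels = tf.constant([[[0, 2]]])\n",
        "\n",
        "# One loss value per pixel: -log(0.7) and -log(0.8), roughly 0.357 and 0.223.\n",
        "per_pixel_loss = tf.keras.losses.sparse_categorical_crossentropy(\n",
        "    toy_labels, toy_probs)\n",
        "print(per_pixel_loss.numpy())\n",
        "\n",
        "# The predicted label is the channel with the highest value: classes 0 and 2.\n",
        "print(tf.argmax(toy_probs, axis=-1).numpy())"
      ]
    },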
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "6he36HK5uKAc"
      },
      "outputs": [],
      "source": [
        "model = ImageSegmentationModel(OUTPUT_CHANNELS)\n",
        "model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',\n",
        "              metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Tc3MiEO2twLS"
      },
      "source": [
        "Let's try out the model to see what it predicts before training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "UwvIKLZPtxV_"
      },
      "outputs": [],
      "source": [
        "def create_mask(pred_mask):\n",
        "  pred_mask = tf.argmax(pred_mask, axis=-1)\n",
        "  pred_mask = pred_mask[..., tf.newaxis]\n",
        "  return pred_mask[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "YLNsrynNtx4d"
      },
      "outputs": [],
      "source": [
        "def show_predictions(dataset=None, num=1):\n",
        "  if dataset:\n",
        "    for image, mask in dataset.take(num):\n",
        "      pred_mask = model.predict(image)\n",
        "      display([image[0], mask[0], create_mask(pred_mask)])\n",
        "  else:\n",
        "    display([sample_image, sample_mask,\n",
        "             create_mask(model.predict(sample_image[tf.newaxis, ...]))])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "X_1CC0T4dho3"
      },
      "outputs": [],
      "source": [
        "show_predictions()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "22AyVYWQdkgk"
      },
      "source": [
        "Let's observe how the model improves while it is training. To accomplish this task, a callback function is defined below. Since in this model, we did not use a pretrained model as Encoder, so the model has to learn everything from scratch. As you can see, in the first a few epochs, the model cannot really predict the mask - a blank mask was predicted. Only after about 10 epochs, the model prediction started to show something that makes sense."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "wHrHsqijdmL6"
      },
      "outputs": [],
      "source": [
        "class DisplayCallback(tf.keras.callbacks.Callback):\n",
        "  def on_epoch_end(self, epoch, logs=None):\n",
        "    clear_output(wait=True)\n",
        "    show_predictions()\n",
        "    print ('\\nSample Prediction after epoch {}\\n'.format(epoch+1))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "StKDH_B9t4SD"
      },
      "outputs": [],
      "source": [
        "EPOCHS = 32\n",
        "VAL_SUBSPLITS = 5\n",
        "VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS\n",
        "\n",
        "model_history = model.fit(train_dataset, epochs=EPOCHS,\n",
        "                          steps_per_epoch=STEPS_PER_EPOCH,\n",
        "                          validation_steps=VALIDATION_STEPS,\n",
        "                          validation_data=test_dataset,\n",
        "                          callbacks=[DisplayCallback()])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "P_mu0SAbt40Q"
      },
      "outputs": [],
      "source": [
        "loss = model_history.history['loss']\n",
        "val_loss = model_history.history['val_loss']\n",
        "\n",
        "epochs = range(EPOCHS)\n",
        "\n",
        "plt.figure()\n",
        "plt.plot(epochs, loss, 'r', label='Training loss')\n",
        "plt.plot(epochs, val_loss, 'bo', label='Validation loss')\n",
        "plt.title('Training and Validation Loss')\n",
        "plt.xlabel('Epoch')\n",
        "plt.ylabel('Loss Value')\n",
        "plt.ylim([0, 2])\n",
        "plt.legend()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "unP3cnxo_N72"
      },
      "source": [
        "## Make predictions"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7BVXldSo-0mW"
      },
      "source": [
        "Let's make some predictions. In the interest of saving time, the number of epochs was kept small, but you may set this higher to achieve more accurate results."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ikrzoG24qwf5"
      },
      "outputs": [],
      "source": [
        "show_predictions(test_dataset, 3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "R24tahEqmSCk"
      },
      "source": [
        "## Next steps\n",
        "Now that you have an understanding of what image segmentation is and how it works, you can try this tutorial out with different intermediate layer outputs, or even different pretrained model. You may also challenge yourself by trying out the [Carvana](https://www.kaggle.com/c/carvana-image-masking-challenge/overview) image masking challenge hosted on Kaggle.\n",
        "\n",
        "You may also want to explore the [UNet model](https://arxiv.org/pdf/1505.04597.pdf), which has cross connections between encoder and decoder block. These cross connections allow better high resolution localization since the image doesn't have to go through all the the down-sampling steps to get to the output."
      ]
    }
  ]
}
